# Switching-current hazard models, their log-likelihoods, and PyTensor Ops that
# expose those log-likelihoods to PyMC.

import numpy as np
from scipy import stats
from scipy import constants
from scipy.integrate import quad, simpson, cumtrapz
import pymc as pm
import matplotlib.pyplot as plt
import seaborn as sns

import pytensor.tensor as pt

import arviz as az

import rcsj_model_analysis
from rcsj_model_analysis import simulating_data
import xarray as xr
from rcsj_model_analysis import io


# Physical constants (SI values from scipy.constants) used by the hazard models.
h_bar, k_B, elementary_charge = (
    constants.hbar,  # reduced Planck constant [J s]
    constants.Boltzmann,  # Boltzmann constant [J/K]
    constants.elementary_charge,  # elementary charge [C]
)


# define a pytensor Op for our likelihood function
class LogLike(pt.Op):
    """
    PyTensor Op that wraps a NumPy/SciPy log-likelihood so PyMC can call it.

    The Op takes a single float64 vector of model parameters (``theta``) and
    returns a float64 scalar, ``loglike(theta, failure_current)``. The work is
    done in plain NumPy inside ``perform``. No ``grad`` method is defined, so
    PyTensor cannot differentiate through this Op, and gradient-based samplers
    (e.g. NUTS) cannot be used with it.
    """

    itypes = [pt.dvector]  # one float64 vector: the parameter vector theta
    otypes = [pt.dscalar]  # one float64 scalar: the log-likelihood

    def __init__(self, loglike, failure_current):
        """
        Store the likelihood function and the observed data.

        Parameters
        ----------
        loglike : callable
            Called as ``loglike(theta, failure_current)``. It must return a
            scalar log-likelihood.
        failure_current : array-like
            The observed failure (switching) currents. They are passed
            unchanged to ``loglike`` on every evaluation.

        Model parameters (e.g. critical current, plasma frequency,
        temperature) are not constructor arguments. They arrive at run time
        as the entries of ``theta``.
        """
        self.likelihood = loglike
        self.failure_current = failure_current

    def perform(self, node, inputs, outputs):
        """Evaluate the wrapped log-likelihood for the incoming ``theta``."""
        (theta,) = inputs

        logl = self.likelihood(theta, self.failure_current)

        # A dscalar output must be a 0-d float64 array. Cast explicitly rather
        # than relying on whatever dtype the likelihood happens to return.
        outputs[0][0] = np.array(logl, dtype=np.float64)

def hazard_func(critical_current, f_j0, Temp, I):
    """
    Arrhenius-type switching rate (hazard) at bias current ``I``.

        eta     = I / critical_current
        f_j     = f_j0 * (1 - eta**2) ** (1/4)
        E_J     = h_bar * critical_current / (2 * elementary_charge)
        Delta_U = 2 * E_J * (sqrt(1 - eta**2) - eta * arccos(eta))
        Gamma   = f_j * exp(-Delta_U / (k_B * Temp))

    Parameters
    ----------
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-bias plasma frequency; the hazard has the same units.
    Temp : float
        Temperature of the system (divided into an energy by ``k_B``).
    I : float or ndarray
        Bias current(s). Must support arithmetic with floats; a plain Python
        list does not. The result broadcasts elementwise over ``I``.

    Returns
    -------
    float or ndarray
        The hazard at each current. It is 0 at ``I == critical_current``.
        For ``|I| > critical_current`` it is NaN, and NumPy emits a
        RuntimeWarning.
    """
    eta = I / critical_current
    # np.power rather than ``**``: for a plain-float I above the critical
    # current, ``**`` would return a Python complex, whereas np.power gives a
    # float NaN consistent with the np.sqrt/np.arccos terms below.
    freq_j = f_j0 * np.power(1 - eta**2, 1 / 4)
    E_J = h_bar * critical_current / (2 * elementary_charge)
    Delta_U = 2 * E_J * (np.sqrt(1 - eta**2) - eta * np.arccos(eta))

    return freq_j * np.exp(-Delta_U / (Temp * k_B))

def hazard_new(critical_current, f_j0, Temp, U_0, I):
    """
    Arrhenius-type switching rate (hazard) with a scaled potential barrier.

    Same model as ``hazard_func``, except the barrier height is multiplied by
    ``U_0``:

        eta     = I / critical_current
        f_j     = f_j0 * (1 - eta**2) ** (1/4)
        E_J     = h_bar * critical_current / (2 * elementary_charge)
        Delta_U = U_0 * 2 * E_J * (sqrt(1 - eta**2) - eta * arccos(eta))
        Gamma   = f_j * exp(-Delta_U / (k_B * Temp))

    Parameters
    ----------
    critical_current : float
        Critical current of the junction.
    f_j0 : float
        Zero-bias plasma frequency; the hazard has the same units.
    Temp : float
        Temperature of the system (divided into an energy by ``k_B``).
    U_0 : float
        Dimensionless scaling factor for the potential barrier.
    I : float or ndarray
        Bias current(s). Must support arithmetic with floats; a plain Python
        list does not. The result broadcasts elementwise over ``I``.

    Returns
    -------
    float or ndarray
        The hazard at each current. It is 0 at ``I == critical_current``.
        For ``|I| > critical_current`` it is NaN, and NumPy emits a
        RuntimeWarning; callers must keep currents within the critical
        current.
    """
    eta = I / critical_current
    # np.power rather than ``**``: for a plain-float I above the critical
    # current, ``**`` would return a Python complex, whereas np.power gives a
    # float NaN consistent with the np.sqrt/np.arccos terms below.
    freq_j = f_j0 * np.power(1 - eta**2, 1 / 4)
    E_J = h_bar * critical_current / (2 * elementary_charge)
    Delta_U = U_0 * 2 * E_J * (np.sqrt(1 - eta**2) - eta * np.arccos(eta))

    return freq_j * np.exp(-Delta_U / (Temp * k_B))

def loglike(theta, I, numpoints=100, ramp_speed=1.0, eps=1e-45):
    """
    Log-likelihood of observed switching currents under ``hazard_func``.

    With the bias current ramped at rate ``v = ramp_speed``, the density of
    switching at current ``I_k`` is

        p(I_k) = Gamma(I_k) / v * exp(-integral_0^{I_k} Gamma(s) / v ds),

    where ``Gamma = hazard_func``. Each integral is evaluated with Simpson's
    rule on ``numpoints`` evenly spaced currents from 0 to ``I_k``.

    Parameters
    ----------
    theta : sequence of 3 floats
        ``(critical_current, f_j0, Temp)``.
    I : float or ndarray
        Observed switching currents.
    numpoints : int, optional
        Number of grid points used for each integral.
    ramp_speed : float, optional
        Current ramp rate ``v``. The default of 1.0 leaves the hazard
        unscaled.
    eps : float, optional
        Small density added inside the log. A hazard that underflows to zero
        then yields a large negative but finite log-likelihood instead of
        ``-inf``, which can make PyMC fail at its starting point.

    Returns
    -------
    float
        Sum of ``log p(I_k)`` over all observations. The result is NaN if any
        observed current exceeds the critical current in magnitude.
    """
    critical_current, f_j0, Temp = theta
    # Shape (numpoints, *np.shape(I)): one integration grid per observation,
    # integrated along axis 0.
    s = np.linspace(0, I, numpoints)
    integrals = simpson(
        hazard_func(critical_current, f_j0, Temp, s) / ramp_speed, x=s, axis=0
    )
    log_density = np.log(
        hazard_func(critical_current, f_j0, Temp, I) / ramp_speed + eps
    )
    return np.sum(log_density - integrals)


# Define the log likelihood
def loglike_new(theta, I, numpoints=100, eps=1e-45):
    """
    Log-likelihood of observed switching currents under ``hazard_new``.

    With the bias current ramped at rate ``v = ramp_speed``, the density of
    switching at current ``I_k`` is

        p(I_k) = Gamma(I_k) / v * exp(-integral_0^{I_k} Gamma(s) / v ds),

    where ``Gamma = hazard_new``. Each integral is evaluated with Simpson's
    rule on ``numpoints`` evenly spaced currents from 0 to ``I_k``.

    Parameters
    ----------
    theta : sequence of 5 floats
        ``(critical_current, f_j0, Temp, U_0, ramp_speed)``.
    I : float or ndarray
        Observed switching currents.
    numpoints : int, optional
        Number of grid points used for each integral.
    eps : float, optional
        Small density added inside the log. A hazard that underflows to zero
        then yields a large negative but finite log-likelihood instead of
        ``-inf``, which can make PyMC fail at its starting point.

    Returns
    -------
    float
        Sum of ``log p(I_k)`` over all observations. The result is NaN if any
        observed current exceeds the critical current in magnitude.
    """
    critical_current, f_j0, Temp, U_0, ramp_speed = theta
    # Shape (numpoints, *np.shape(I)): one integration grid per observation,
    # integrated along axis 0.
    s = np.linspace(0, I, numpoints)
    rate_on_grid = hazard_new(critical_current, f_j0, Temp, U_0, s) / ramp_speed
    integrals = simpson(rate_on_grid, x=s, axis=0)
    log_density = np.log(
        hazard_new(critical_current, f_j0, Temp, U_0, I) / ramp_speed + eps
    )
    return np.sum(log_density - integrals)


class LogLike_new(LogLike):
    """
    PyTensor Op wrapping the five-parameter log-likelihood (``loglike_new``).

    Mechanically identical to ``LogLike``: a float64 parameter vector goes in
    and a float64 log-likelihood scalar comes out. It inherits ``itypes``,
    ``otypes``, ``__init__(loglike, failure_current)`` and ``perform`` from
    ``LogLike`` instead of duplicating them.

    It is kept as a separate name for the model whose ``theta`` is
    ``(critical_current, f_j0, Temp, U_0, ramp_speed)``, as unpacked by
    ``loglike_new``. Those parameters are supplied at run time inside
    ``theta``, not to the constructor.
    """